This script analyzes filtered mAb escape data¶
In [1]:
# this cell is tagged as parameters for `papermill` parameterization
HENV103_filter = None
HENV117_filter = None
HENV26_filter = None
HENV32_filter = None
m102_filter = None
nAH1_filter = None
altair_config = None
nipah_config = None
escape_bubble_plot = None
bubble_1_mut_plot = None
mab_line_escape_plot = None
mab_plot_top = None
mab_plot_all = None
In [2]:
# Parameters
nipah_config = "nipah_config.yaml"
altair_config = "data/custom_analyses_data/theme.py"
HENV103_filter = "results/filtered_data/HENV103_escape_filtered.csv"
HENV117_filter = "results/filtered_data/HENV117_escape_filtered.csv"
HENV26_filter = "results/filtered_data/HENV26_escape_filtered.csv"
HENV32_filter = "results/filtered_data/HENV32_escape_filtered.csv"
m102_filter = "results/filtered_data/m102_escape_filtered.csv"
nAH1_filter = "results/filtered_data/nAH1_escape_filtered.csv"
escape_bubble_plot = "results/images/escape_bubble_plot.html"
bubble_1_mut_plot = "results/images/escape_bubble_1_mut_plot.html"
overlap_escape_plot = "results/images/overlap_escape_plot.html"
mab_line_escape_plot = "results/images/mab_line_escape_plot.html"
mab_plot_top = "results/images/mab_plot_top.html"
mab_plot_all = "results/images/mab_plot_all.html"
In [3]:
import math
import os
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import Bio.SeqIO
import yaml
import matplotlib
matplotlib.rcParams['svg.fonttype'] = 'none'
from Bio import PDB
import dmslogo
from dmslogo.colorschemes import CBPALETTE
from dmslogo.colorschemes import ValueToColorMap
In [4]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()
if os.getcwd() == '/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/':
pass
print("Already in correct directory")
else:
os.chdir("/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/")
print("Setup in correct directory")
Setup in correct directory
In [5]:
#altair_config = 'data/custom_analyses_data/theme.py'
#nipah_config = 'nipah_config.yaml'
#
#HENV103_filter = 'results/filtered_data/HENV103_escape_filtered.csv'
#HENV117_filter = 'results/filtered_data/HENV117_escape_filtered.csv'
#HENV26_filter = 'results/filtered_data/HENV26_escape_filtered.csv'
#HENV32_filter = 'results/filtered_data/HENV32_escape_filtered.csv'
#m102_filter = 'results/filtered_data/m102_escape_filtered.csv'
#nAH1_filter = 'results/filtered_data/nAH1_escape_filtered.csv'
#
#escape_bubble_plot = 'results/images/escape_bubble_plot.html'
#bubble_1_mut_plot = 'results/images/escape_bubble_1_mut_plot.html'
#overlap_escape_plot = 'results/images/overlap_escape_plot.html'
#
#m102_heat = 'results/images/m102_heatmap.html'
#HENV26_heat = 'results/images/HENV26_heatmap.html'
#HENV32_heat = 'results/images/HENV32_heatmap.html'
#nAH1_heat = 'results/images/nAH1_heatmap.html'
#HENV117_heat = 'results/images/HENV117_heatmap.html'
#HENV103_heat = 'results/images/HENV103_heatmap.html'
In [6]:
if altair_config:
with open(altair_config, 'r') as file:
exec(file.read())
with open(nipah_config) as f:
config = yaml.safe_load(f)
Make logo plots¶
Filtering parameters¶
In [7]:
# Make a dataframe with all the mutants with low entry scores for masking later in script
func_scores_E3 = pd.read_csv('../Nipah_Malaysia_RBP_DMS/results/func_effects/averages/CHO_EFNB3_low_func_effects.csv')
func_scores_E3_low_effect = func_scores_E3[
(func_scores_E3['effect'] < config['min_func_effect_for_ab']) &
(func_scores_E3['times_seen'] > config['func_times_seen_cutoff']) &
(func_scores_E3['site'] != 603) &
(func_scores_E3['mutant'] != '-') &
(func_scores_E3['mutant'] != '*')
]
display(func_scores_E3_low_effect)
| site | wildtype | mutant | effect | effect_std | times_seen | n_selections | |
|---|---|---|---|---|---|---|---|
| 13 | 71 | Q | P | -3.506 | 0.00000 | 6.714 | 7 |
| 23 | 72 | N | C | -2.485 | 0.62820 | 5.429 | 7 |
| 34 | 72 | N | P | -3.545 | 0.00000 | 6.714 | 7 |
| 39 | 72 | N | V | -3.084 | 0.06917 | 6.000 | 7 |
| 55 | 73 | Y | P | -3.412 | 0.90800 | 3.833 | 6 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 10753 | 597 | I | S | -2.796 | 0.08909 | 3.000 | 7 |
| 10754 | 597 | I | T | -3.526 | 0.00000 | 6.857 | 7 |
| 10759 | 598 | P | C | -2.300 | 0.60270 | 4.571 | 7 |
| 10776 | 598 | P | W | -3.177 | 0.00000 | 8.143 | 7 |
| 10777 | 598 | P | Y | -2.057 | 0.20350 | 3.286 | 7 |
2711 rows × 7 columns
Read in filtered antibody escape files and combine.¶
In [8]:
HENV103 = pd.read_csv(HENV103_filter)
HENV117 = pd.read_csv(HENV117_filter)
HENV26 = pd.read_csv(HENV26_filter)
HENV32 = pd.read_csv(HENV32_filter)
m102 = pd.read_csv(m102_filter)
nAH1 = pd.read_csv(nAH1_filter)
# Combine all the individual filtered antibody escape files
combined_df = pd.concat([HENV103,HENV117,HENV26,HENV32,m102,nAH1])
combined_df = combined_df[['site','wildtype','mutant','mutation','effect','escape_median','escape_std','times_seen_ab','show_site','ab']]
display(combined_df)
# Make a separate dataframe that only has the top sites
filtered_df = combined_df.query('show_site == True')
filtered_df = filtered_df[filtered_df['escape_median'] >= config['min_escape_cutoff']]
display(filtered_df)
| site | wildtype | mutant | mutation | effect | escape_median | escape_std | times_seen_ab | show_site | ab | |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 71 | Q | D | Q71D | -0.7886 | -0.002836 | 0.05669 | 3.000 | False | HENV-103 |
| 1 | 71 | Q | E | Q71E | 0.4129 | -0.044930 | 0.12250 | 3.333 | False | HENV-103 |
| 2 | 71 | Q | F | Q71F | -0.3917 | 0.026290 | 0.01273 | 2.333 | False | HENV-103 |
| 3 | 71 | Q | G | Q71G | -0.3752 | 0.012700 | 0.01622 | 3.000 | False | HENV-103 |
| 4 | 71 | Q | H | Q71H | -0.1068 | 0.015420 | 0.22880 | 2.667 | False | HENV-103 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 6911 | 602 | T | R | T602R | 0.5666 | -0.023040 | 0.06005 | 6.667 | False | nAH1.3 |
| 6912 | 602 | T | S | T602S | 0.2874 | 0.160600 | 0.24200 | 3.667 | False | nAH1.3 |
| 6913 | 602 | T | V | T602V | 0.4577 | 0.134100 | 0.11030 | 5.000 | False | nAH1.3 |
| 6914 | 602 | T | W | T602W | 0.5192 | 0.060960 | 0.13720 | 5.667 | False | nAH1.3 |
| 6915 | 602 | T | Y | T602Y | 0.5354 | 0.121700 | 0.11090 | 5.000 | False | nAH1.3 |
41447 rows × 10 columns
| site | wildtype | mutant | mutation | effect | escape_median | escape_std | times_seen_ab | show_site | ab | |
|---|---|---|---|---|---|---|---|---|---|---|
| 861 | 151 | P | A | P151A | -0.4827 | 0.3086 | 0.27500 | 4.000 | True | HENV-103 |
| 863 | 151 | P | G | P151G | -0.3962 | 0.4608 | 0.29980 | 4.333 | True | HENV-103 |
| 864 | 151 | P | I | P151I | -1.7310 | 0.3507 | 0.39990 | 2.667 | True | HENV-103 |
| 866 | 151 | P | L | P151L | -1.5310 | 0.3130 | 0.19900 | 9.333 | True | HENV-103 |
| 902 | 154 | K | I | K154I | -0.9768 | 0.5460 | 0.33750 | 5.333 | True | HENV-103 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 5948 | 519 | W | Q | W519Q | 0.5198 | 0.5431 | 0.17250 | 4.333 | True | nAH1.3 |
| 5949 | 519 | W | R | W519R | -0.2430 | 0.8282 | 0.21840 | 11.000 | True | nAH1.3 |
| 5950 | 519 | W | S | W519S | 0.2115 | 0.4010 | 0.07602 | 6.333 | True | nAH1.3 |
| 5962 | 520 | I | T | I520T | 0.1253 | 1.6290 | 1.27900 | 6.000 | True | nAH1.3 |
| 5964 | 520 | I | W | I520W | -1.6150 | 0.6892 | 0.16750 | 5.000 | True | nAH1.3 |
515 rows × 10 columns
In [9]:
def identify_escape_sites(df, ab):
subset = df[(df['ab'] == ab)]
unique_sites = list(subset['site'].unique())
return unique_sites
abs = ['HENV-26', 'HENV-103', 'HENV-32', 'HENV-117', 'm102.4', 'nAH1.3']
sites_dict = {} # Create an empty dictionary to store the results
for ab in abs:
sites_dict[ab] = identify_escape_sites(filtered_df, ab)
display(sites_dict) #need site dict for later
{'HENV-26': [166,
167,
171,
176,
204,
233,
257,
490,
491,
492,
494,
497,
501,
529,
530,
531,
589],
'HENV-103': [151,
154,
176,
205,
242,
258,
259,
260,
261,
264,
268,
273,
274,
275,
277],
'HENV-32': [151,
154,
176,
199,
200,
201,
205,
207,
268,
274,
275,
277,
509,
534,
556,
593,
596],
'HENV-117': [171,
172,
204,
208,
217,
257,
351,
555,
580,
582,
583,
586,
587,
588,
589],
'm102.4': [171,
172,
239,
243,
270,
305,
507,
532,
542,
555,
559,
577,
582,
586,
587,
588,
589],
'nAH1.3': [184,
185,
188,
189,
190,
447,
448,
450,
468,
516,
517,
518,
519,
520]}
Plot bubble chart showing mAb escape for individual mutants by functional score for both E2 or E3¶
In [10]:
order_ab = ['m102.4','HENV-26','HENV-117','HENV-103','HENV-32','nAH1.3']
def generate_chart(df):
variant_selector = alt.selection_point(
on="mouseover",
empty=False,
fields=["site"],
value=1
)
chart = alt.Chart(df).mark_point(filled=True, opacity=0.2).encode(
x=alt.X('ab:O',sort=order_ab, title='Antibody', axis=alt.Axis(labelAngle=-45,grid=True)),
y=alt.Y('effect:Q', title='Cell Entry of Top Escape', axis=alt.Axis(grid=True, tickCount=4,values=[0.5, 0, -0.5, -1, -1.5, -2])),
size=alt.Size('escape_median', legend=alt.Legend(title='Mean Escape By Mutation')),
xOffset='random:Q',
tooltip=['site','wildtype','mutant','ab', 'effect','escape_median','escape_std'],
color=alt.Color('ab').legend(None),
opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
strokeWidth=alt.condition(variant_selector,alt.value(2),alt.value(0))
).transform_calculate(
random="sqrt(-1*log(random()))*cos(2*PI*random())"
#random='random'
).properties(
height=300,
width=300
).add_params(variant_selector)
return chart
escape_bubble = generate_chart(filtered_df)
escape_bubble.display()
escape_bubble.save(escape_bubble_plot)
Now summarize by number of mutations between wildtype and mutant codons¶
In [11]:
# Load in wt nucleotide sequence (which is different than the 'wt' sequence from Library as it was codon optimized)
niv_m_wt = str(Bio.SeqIO.read('data/custom_analyses_data/alignments/wild_type_seq.fasta', 'fasta').seq)
codon_table = {
"ATA":"I", "ATC":"I", "ATT":"I", "ATG":"M",
"ACA":"T", "ACC":"T", "ACG":"T", "ACT":"T",
"AAC":"N", "AAT":"N", "AAA":"K", "AAG":"K",
"AGC":"S", "AGT":"S", "AGA":"R", "AGG":"R",
"CTA":"L", "CTC":"L", "CTG":"L", "CTT":"L",
"CCA":"P", "CCC":"P", "CCG":"P", "CCT":"P",
"CAC":"H", "CAT":"H", "CAA":"Q", "CAG":"Q",
"CGA":"R", "CGC":"R", "CGG":"R", "CGT":"R",
"GTA":"V", "GTC":"V", "GTG":"V", "GTT":"V",
"GCA":"A", "GCC":"A", "GCG":"A", "GCT":"A",
"GAC":"D", "GAT":"D", "GAA":"E", "GAG":"E",
"GGA":"G", "GGC":"G", "GGG":"G", "GGT":"G",
"TCA":"S", "TCC":"S", "TCG":"S", "TCT":"S",
"TTC":"F", "TTT":"F", "TTA":"L", "TTG":"L",
"TAC":"Y", "TAT":"Y", "TAA":"*", "TAG":"*",
"TGC":"C", "TGT":"C", "TGA":"*", "TGG":"W"
}
def find_closest_codon(wt_codon, mutant_aa):
mutant_codons = [codon for codon, aa in codon_table.items() if aa == mutant_aa]
min_mutations = 3 # Maximum mutations possible
closest_codon = None
for m_codon in mutant_codons:
mutations = sum([1 for c1, c2 in zip(wt_codon, m_codon) if c1 != c2])
if mutations < min_mutations:
min_mutations = mutations
closest_codon = m_codon
return closest_codon, min_mutations
# Function to extract codon for a given site
def extract_codon(site):
idx = (site - 1) * 3
return niv_m_wt[idx: idx + 3]
def extract_codon_niv_b(site):
idx = (site - 1) * 3
return niv_m_wt[idx: idx + 3]
def apply_codon_to_df(df,extract_func):
df['wt_codon'] = df['site'].apply(extract_func)
df['closest_mutant_codon'] = df.apply(lambda row: find_closest_codon(row['wt_codon'], row['mutant'])[0], axis=1)
df['min_mutations'] = df.apply(lambda row: find_closest_codon(row['wt_codon'], row['mutant'])[1], axis=1)
return df
combined_df = apply_codon_to_df(combined_df,extract_codon)
filtered_df = apply_codon_to_df(filtered_df,extract_codon)
In [12]:
def plot_escape_and_mutations_away(df):
variant_selector = alt.selection_point(
on="mouseover",
empty=False,
fields=["site"],
value=1
)
df = df[df['min_mutations'] == 1]
chart = alt.Chart(df).mark_point(filled=True,opacity=0.4).encode(
x=alt.X('ab:O',sort=order_ab,title=None,axis=alt.Axis(labelAngle=-45,grid=True)),
y=alt.Y('effect:Q', title='Cell Entry of Escape Mutants',axis=alt.Axis(grid=True,tickCount=4,values=[0.5, 0, -0.5, -1, -1.5, -2])), # 'Q' denotes a quantitative variable
size=alt.Size('escape_median',legend=alt.Legend(title='Escape of Mutant')),
xOffset='random:Q',
tooltip=['ab','effect','escape_median','site','mutant'],
opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
strokeWidth=alt.condition(variant_selector,alt.value(2),alt.value(0)),
color=alt.Color('ab:N').legend(None),
).transform_calculate(
#random='random()'
random="sqrt(-2*log(random()))*cos(2*PI*random())"
).properties(
height=300,
width=300
).add_params(variant_selector)
return chart
bubble_plot_1_mut_away = plot_escape_and_mutations_away(filtered_df)
bubble_plot_1_mut_away.display()
bubble_plot_1_mut_away.save(bubble_1_mut_plot)
In [13]:
def find_overlapping_escape(df):
slider = alt.binding_range(min=config['min_func_effect_for_ab'], max=1, step=0.25, name="effect")
selector = alt.param(name="SelectorName", value=-4, bind=slider)
radio = alt.binding_radio(options=[1, 2, 3], name='Min Mutations:')
mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)
df_filtered = df
# Group by 'site' and 'mutant', count the unique 'ab' values for each group
grouped = df_filtered.groupby(['site', 'mutant'])['ab'].nunique().reset_index()
# Filter groups where the count of unique 'ab' values is at least 2
result = grouped[grouped['ab'] >= 2]
# Merge the result with the original dataframe to get the full rows
df_result = pd.merge(df, result[['site', 'mutant']], on=['site', 'mutant'])
df_result['mutation_number'] = df_result['mutation'].str.extract('(\d+)').astype(int)
base = (
alt.Chart(df_result).mark_rect().encode(
x=alt.X('mutation:O', title='Site',sort=alt.EncodingSortField(field='mutation_number'), axis=alt.Axis(labelAngle=-90,grid=False)),
y=alt.Y('ab:O', title='Mutant',sort=order_ab,axis=alt.Axis(grid=False)), # Apply custom sort order here
color='escape_median',
#opacity=alt.condition(alt.datum.effect >= selector, alt.value(1), alt.value(0)),
tooltip=['site','wildtype','mutant','escape_median','min_mutations'],
).properties(
width=alt.Step(30),
height=alt.Step(20)
)
).add_params(selector,mutation_selector).transform_filter(
(alt.datum.effect >= selector) & (alt.datum.min_mutations == mutation_selector)
)
return base
overlap_escape = find_overlapping_escape(filtered_df)
overlap_escape.display()
overlap_escape.save(overlap_escape_plot)
Line plots of escape¶
In [14]:
def plot_line_escape(df):
variant_selector = alt.selection_point(
on="mouseover",
empty=False,
fields=["site"],
value=0
)
# Group by 'site' and 'mutant', count the unique 'ab' values for each group
summed = df.groupby(['site','ab'])['escape_median'].sum().reset_index()
empty_chart = []
ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
for idx, ab in enumerate(ab_list):
tmp_df = summed[summed['ab'] == ab]
#color = '#1f4e79'
if ab in ['m102.4','HENV-26','HENV-117']:
color = "#1f4e79"
if ab in ['HENV-103','HENV-32']:
color = "#ff7f0e"
if ab in ['nAH1.3']:
color = "#2ca02c"
# Conditionally set the x-axis labels and title for the last plot
is_last_plot = idx == len(ab_list) - 1
x_axis = alt.Axis(values=[100, 200, 300, 400, 500, 600], tickCount=6, labelAngle=-90, grid=True,
labelExpr="datum.value % 100 === 0 ? datum.value : ''",
title="Site" if is_last_plot else None,
labels=is_last_plot) # Only show labels for the last plot
base = (
alt.Chart(tmp_df).mark_line(size=1, color=color).encode(
x=alt.X('site:O', axis=x_axis),
y=alt.Y('escape_median', title=f'{ab}', axis=alt.Axis(grid=True,tickCount=3)),
).properties(
width=1000,
height=100
)
)
point = base.mark_point(color='black',size=10,filled=True).encode(
x=alt.X('site:O', axis=x_axis),
y=alt.Y('escape_median', title=f'{ab}', axis=alt.Axis(grid=True)),
size=alt.condition(variant_selector, alt.value(100),alt.value(15)),
color=alt.condition(variant_selector, alt.value('black'), alt.value(color)),
tooltip=['site','escape_median'],
).properties(
width=1000,
height=100,
).add_params(variant_selector)
chart = base + point
empty_chart.append(chart)
# Use configure_concat to adjust spacing between vertically concatenated plots
combined_chart = alt.vconcat(*empty_chart, spacing=1).resolve_scale(y='independent', x='shared', color='independent').properties(title='Summed Antibody Escape by Site and Colored by Epitope')
return combined_chart
tmp_line = plot_line_escape(combined_df)
tmp_line.display()
tmp_line.save(mab_line_escape_plot)
Now calculate atomic distances between escape sites and closest amino acid in heavy and light chains¶
In [15]:
def calculate_min_distances(pdb_path, source_chain_id, target_chain_ids, name):
# Initialize the PDB parser and load the structure
parser = PDB.PDBParser(QUIET=True)
structure = parser.get_structure('structure_id', pdb_path)
source_chain = structure[0][source_chain_id]
target_chains = [structure[0][chain_id] for chain_id in target_chain_ids]
data = []
for residueA in source_chain:
if residueA.resname in ["HOH", "WAT", "IPA", "NAG"]:
continue
min_distance = float('inf')
closest_residueB = None
closest_chain_id = None
residues_within_4 = 0
for target_chain in target_chains:
for residueB in target_chain:
if residueB.resname in ["HOH", "WAT", "IPA"]:
continue
# Check for residues within 4 angstroms
is_within_4 = False
for atomA in residueA:
for atomB in residueB:
distance = atomA - atomB
if distance < min_distance:
min_distance = distance
closest_residueB = residueB
closest_chain_id = target_chain.get_id()
if distance < 4:
is_within_4 = True
if is_within_4:
residues_within_4 += 1
data.append({
'wildtype': residueA.resname,
'site': residueA.id[1],
'chain': closest_chain_id,
'residue': closest_residueB.id[1],
'residue_name': closest_residueB.resname,
'distance': min_distance,
'residues_within_4': residues_within_4,
'ab': name
})
# Convert data to pandas DataFrame
df = pd.DataFrame(data)
return df
def check_file(input_path,source_chain,target_chain,name,output_path):
file_path = output_path
if not os.path.exists(file_path):
print(f'File {name} does not exist, running calculation')
output_df = calculate_min_distances(input_path,source_chain,target_chain,name)
print(f'done calculating for {file_path}')
output_df.to_csv(output_path,index=False)
return output_df
else:
print("File already exists,loading from disk")
output_df = pd.read_csv(output_path)
return output_df
pdb_path_26 = 'data/custom_analyses_data/crystal_structures/6vy5.pdb'
source_chain_26 = 'A'
target_chains_26 = ['H', 'L']
output_path_26 = 'results/distances/df_HENV26_atomic_distances.csv'
pdb_path_32 = 'data/custom_analyses_data/crystal_structures/6vy4.pdb'
source_chain_32 = 'A'
target_chains_32 = ['H', 'L']
output_path_32 = 'results/distances/df_HENV32_atomic_distances.csv'
pdb_path_nah = 'data/custom_analyses_data/crystal_structures/7txz.pdb'
source_chain_nah = 'A'
target_chains_nah = ['F', 'E']
output_path_nah = 'results/distances/df_nAH_atomic_distances.csv'
pdb_path_m102 = 'data/custom_analyses_data/crystal_structures/6cmg.pdb'
source_chain_m102 = 'A'
target_chains_m102 = ['B', 'C']
output_path_m102 = 'results/distances/df_m102_atomic_distances.csv'
df_HENV26 = check_file(pdb_path_26, source_chain_26, target_chains_26, 'HENV-26',output_path_26)
df_HENV32 = check_file(pdb_path_32, source_chain_32, target_chains_32, 'HENV-32',output_path_32)
df_nah = check_file(pdb_path_nah, source_chain_nah, target_chains_nah, 'nAH1.3',output_path_nah)
df_nah['chain'].replace({'E': 'H', 'F': 'L'}, inplace=True) # Fix naming so consistent heavy and light chain naming
df_m102 = check_file(pdb_path_m102, source_chain_m102, target_chains_m102, 'm102.4',output_path_m102)
df_m102['chain'].replace({'C': 'H', 'B': 'L'}, inplace=True) # Fix naming so consistent heavy and light chain naming
File already exists,loading from disk File already exists,loading from disk File already exists,loading from disk File already exists,loading from disk
In [16]:
def find_close_mab_sites(df,name):
unique_sites = df.query('distance <= 4')['site'].unique()
mab_site_list = list(unique_sites)
print(f'Close sites for mAb {name} are: {mab_site_list}')
return mab_site_list
### First find RBP sites that are close to mAb residues
nah_close = find_close_mab_sites(df_nah,'nAH1.3')
HENV26_close = find_close_mab_sites(df_HENV26,'HENV-26')
HENV32_close = find_close_mab_sites(df_HENV32,'HENV-32')
m102_close = find_close_mab_sites(df_m102,'m102.4')
### Now combined the close residues AND the top escape sites identified previously
nah_combined_sites = sites_dict['nAH1.3'] + nah_close
HENV26_combined_sites = sites_dict['HENV-26'] + HENV26_close
HENV32_combined_sites = sites_dict['HENV-32'] + HENV32_close
m102_combined_sites = sites_dict['m102.4'] + m102_close
Close sites for mAb nAH1.3 are: [172, 183, 184, 185, 186, 187, 188, 190, 191, 358, 449, 450, 451, 472, 515, 516, 517, 518, 570] Close sites for mAb HENV-26 are: [389, 401, 403, 404, 458, 488, 489, 490, 491, 492, 494, 497, 501, 504, 505, 506, 528, 529, 530, 531, 532, 533, 555, 556, 557, 581, 586] Close sites for mAb HENV-32 are: [196, 199, 200, 201, 202, 203, 205, 206, 207, 210, 254, 256, 258, 260, 262, 263, 264, 266, 553] Close sites for mAb m102.4 are: [239, 240, 241, 242, 305, 458, 488, 489, 490, 504, 505, 506, 507, 530, 532, 533, 555, 557, 559, 579, 580, 581, 588]
In [17]:
def make_distance(df):
subset_df = df[['site','distance']].copy()
subset_df['mutant'] = 'distance'
subset_df['wildtype'] = ''
subset_df['effect'] = 'escape_median'
subset_df.rename(columns={'distance': 'value'}, inplace=True)
return subset_df
distance_nah_df = make_distance(df_nah)
distance_26_df = make_distance(df_HENV26)
distance_32_df = make_distance(df_HENV32)
distance_m102_df = make_distance(df_m102)
display(distance_m102_df)
| site | value | mutant | wildtype | effect | |
|---|---|---|---|---|---|
| 0 | 176 | 35.044434 | distance | escape_median | |
| 1 | 177 | 31.866014 | distance | escape_median | |
| 2 | 178 | 27.842815 | distance | escape_median | |
| 3 | 179 | 28.777035 | distance | escape_median | |
| 4 | 180 | 28.332012 | distance | escape_median | |
| ... | ... | ... | ... | ... | ... |
| 423 | 599 | 30.802711 | distance | escape_median | |
| 424 | 600 | 28.920950 | distance | escape_median | |
| 425 | 601 | 27.248772 | distance | escape_median | |
| 426 | 602 | 29.020868 | distance | escape_median | |
| 427 | 603 | 27.945621 | distance | escape_median |
428 rows × 5 columns
Prepare dataframes for heatmaps¶
In [18]:
def make_empty_df_with_distance(ab,distance_file):
print(ab)
sites = range(71, 603)
amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
# Create the combination of each site with each amino acid
data = [{'site': site, 'mutant': aa} for site in sites for aa in amino_acids]
# Create the DataFrame
empty_df = pd.DataFrame(data)
all_sites_df = pd.merge(empty_df,combined_df.query(f'ab == "{ab}"'),on=['site','mutant'],how='left')
df_melted = all_sites_df.melt(id_vars=['site', 'mutant', 'wildtype'],
value_vars=['escape_median'],
var_name='effect', value_name='value')
df_filtered = func_scores_E3_low_effect.melt(id_vars=['site', 'mutant', 'wildtype'],
value_vars=['effect'],
var_name='effect', value_name='value')
df_test = pd.concat([df_melted,df_filtered,distance_file],ignore_index=True)
df_test['ab'] = ab
return df_test
empty_df_m102 = make_empty_df_with_distance("m102.4",distance_m102_df)
empty_df_HENV26 = make_empty_df_with_distance("HENV-26",distance_26_df)
empty_df_HENV32 = make_empty_df_with_distance("HENV-32",distance_32_df)
empty_df_nah = make_empty_df_with_distance("nAH1.3",distance_nah_df)
def make_empty_df(ab):
sites = range(71, 603)
amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
# Create the combination of each site with each amino acid
data = [{'site': site, 'mutant': aa} for site in sites for aa in amino_acids]
# Create the DataFrame
empty_df = pd.DataFrame(data)
all_sites_df = pd.merge(empty_df,combined_df.query(f'ab == "{ab}"'),on=['site','mutant'],how='left')
df_melted = all_sites_df.melt(id_vars=['site', 'mutant', 'wildtype'],
value_vars=['escape_median'],
var_name='effect', value_name='value')
df_filtered = func_scores_E3_low_effect.melt(id_vars=['site', 'mutant', 'wildtype'],
value_vars=['effect'],
var_name='effect', value_name='value')
df_test = pd.concat([df_melted,df_filtered],ignore_index=True)
df_test['ab'] = ab
return df_test
empty_df_HENV117 = make_empty_df("HENV-117")
empty_df_HENV103 = make_empty_df("HENV-103")
combined_ab = pd.concat([empty_df_m102,empty_df_HENV26,empty_df_HENV32,empty_df_nah,empty_df_HENV117,empty_df_HENV103])
display(combined_ab)
m102.4 HENV-26 HENV-32 nAH1.3
| site | mutant | wildtype | effect | value | ab | |
|---|---|---|---|---|---|---|
| 0 | 71 | A | NaN | escape_median | NaN | m102.4 |
| 1 | 71 | C | Q | escape_median | 0.294900 | m102.4 |
| 2 | 71 | D | Q | escape_median | -0.064940 | m102.4 |
| 3 | 71 | E | Q | escape_median | 0.020610 | m102.4 |
| 4 | 71 | F | Q | escape_median | 0.006025 | m102.4 |
| ... | ... | ... | ... | ... | ... | ... |
| 13346 | 597 | S | I | effect | -2.796000 | HENV-103 |
| 13347 | 597 | T | I | effect | -3.526000 | HENV-103 |
| 13348 | 598 | C | P | effect | -2.300000 | HENV-103 |
| 13349 | 598 | W | P | effect | -3.177000 | HENV-103 |
| 13350 | 598 | Y | P | effect | -2.057000 | HENV-103 |
81836 rows × 6 columns
In [19]:
def plot_distance_only(df,trigger):
custom_order = ['distance','R','K','H','D','E','Q','N','S','T','Y','W','F','A','I','L','M','V','G','P','C']
all_residues = range(71, 603)
final_df = df
final_df = final_df.sort_values('site') # Sort the dataframe by 'site' to ensure that duplicates are detected correctly.
sort_order = {mutant: i for i, mutant in enumerate(custom_order)} # Create a dictionary that maps each mutant to its sort rank based on the custom order
final_df['mutant_rank'] = final_df['mutant'].map(sort_order) # Map the 'mutant' column to these ranks
final_df = final_df.sort_values('mutant_rank') # Now sort the dataframe by this rank
final_df = final_df.drop(columns=['mutant_rank']) # Drop the 'mutant_rank' column as it is no longer needed after sorting
sites = sorted(final_df['site'].unique(), key=lambda x: float(x))
ab_list = ['m102.4','HENV-26','HENV-117','HENV-103','HENV-32','nAH1.3']
empty_chart = [] #setup collection for charts
for idx, ab in enumerate(ab_list):
tmp_df = final_df[final_df['ab'] == ab]
if ab == 'm102.4':
site_subset = m102_combined_sites
#legend_conditional = alt.Legend(title='Distance to mAb')
if ab == 'HENV-26':
site_subset = HENV26_combined_sites
#legend_conditional = alt.Legend(title='Distance to mAb')
if ab == 'HENV-32':
site_subset = HENV32_combined_sites
#legend_conditional = alt.Legend(title='Distance to mAb')
if ab == 'HENV-103':
site_subset = sites_dict['HENV-103']
#legend_conditional = alt.Legend(title=None)
if ab == 'HENV-117':
site_subset = sites_dict['HENV-117']
#legend_conditional = alt.Legend(title=None)
if ab == 'nAH1.3':
site_subset = nah_combined_sites
#legend_conditional = alt.Legend(title='Distance to mAb')
#select which sites you will show
if trigger == True:
tmp_df = tmp_df[tmp_df['site'].isin(site_subset)]
x_axis = alt.Axis(labelAngle=-90,
#labelExpr="datum.value % 10 === 0 ? datum.value : ''",
title="Site")
else:
tmp_df = tmp_df[tmp_df['site'].isin(all_residues)]
# Conditionally set the x-axis labels and title for the last plot
is_last_plot = idx == len(ab_list) - 1
x_axis = alt.Axis(labelAngle=-90,
labelExpr="datum.value % 10 === 0 ? datum.value : ''",
title="Site" if is_last_plot else None,
labels=True) # Only show labels for the last plot
# Prepare the color scales separately for distance and effects
# Filter out 'distance' values before creating the effect heatmap
effect_df = tmp_df[(tmp_df['mutant'] != 'distance') & (tmp_df['effect'] != 'effect')]
max_color = effect_df['value'].max()
min_color = effect_df['value'].min()
#Adjust color scheme for abs with little sensitizing mutations
if min_color > -1:
min_color = min_color - 1
# Prepare the color scale for effects, Altair will automatically determine the domain
color_scale_escape = alt.Scale(scheme='redblue', domainMid=0,domain=[min_color,max_color])
color_scale_entropy = alt.Scale(scheme='purples', domain=[0, 15],reverse=True)
strokewidth_size = 0.25
unique_wildtypes_df = tmp_df.drop_duplicates(subset=['site', 'wildtype'])
# The chart for the heatmap
base = (
alt.Chart(tmp_df,title=f'{ab}')
.encode(
x=alt.X('site:O', title='Site', sort=sites, axis=x_axis),
y=alt.Y('mutant', title='Amino Acid', sort=alt.EncodingSortField(field='sort_order', order='ascending'),axis=alt.Axis(grid=False)), # Apply custom sort order here
tooltip=['site','wildtype','mutant','value'],
).properties(
width=alt.Step(10),
height=alt.Step(11)
)
)
# Heatmap for distance
chart_empty = (
base.mark_rect(color='#e6e7e8').encode(
).transform_filter(
alt.datum.effect == 'escape_median'
)
)
# Heatmap for effect
chart_effect = (
base.mark_rect(stroke='black',strokeWidth = strokewidth_size).encode(
color=alt.condition('datum.mutant != "distance"',
alt.Color('value:Q', scale=color_scale_escape,legend=alt.Legend(title=f'{ab} Escape')),
alt.value('transparent')),
).transform_filter(
alt.datum.effect == 'escape_median'
)
)
# Heatmap for distance
if ab in ['m102.4','HENV-26','HENV-32','nAH1.3']:
chart_distance = (
base.mark_rect().encode(
color=alt.condition('datum.mutant == "distance"',
alt.Color('value:Q', scale=color_scale_entropy,legend=alt.Legend(title='Distance to mAb')),
alt.value('transparent'))
).transform_filter(
alt.datum.effect == 'escape_median'
)
)
else:
chart_distance = (
base.mark_rect(color='transparent').encode(
#color=alt.Color('white'),
#alt.Color('value:Q', scale=color_scale_entropy,legend=alt.Legend(title='Distance to mAb')),
#alt.value('transparent'))
).transform_filter(
alt.datum.effect == 'escape_median'
)
)
# Heatmap for distance
chart_filtered = (
base.mark_rect(color='#939598',stroke='black',strokeWidth = strokewidth_size).encode(
).transform_filter(
alt.datum.effect == 'effect'
)
)
# The layer for the wildtype boxes
wildtype_layer_box = (
alt.Chart(unique_wildtypes_df).mark_rect(color='white',stroke='black',strokeWidth = strokewidth_size).encode(
x=alt.X('site:O', sort=sites),
y=alt.Y('wildtype', sort=alt.EncodingSortField(field='sort_order', order='ascending')),
opacity=alt.value(1)
)
.transform_filter(
(alt.datum.wildtype != '') & (alt.datum.wildtype != None) & (alt.datum.value != None)
)
)
# The layer for the wildtype amino acids
wildtype_layer = (
alt.Chart(unique_wildtypes_df).mark_text(color='black', text='X', size=8).encode(
x=alt.X('site:O', sort=sites),
y=alt.Y('wildtype', sort=alt.EncodingSortField(field='sort_order', order='ascending')),
opacity=alt.value(1)
)
.transform_filter(
(alt.datum.wildtype != '') & (alt.datum.wildtype != None) & (alt.datum.value != None)
)
)
# Combine the heatmap layer with the wildtype layer
chart = alt.layer(chart_empty,chart_effect,chart_distance,chart_filtered,wildtype_layer_box,wildtype_layer).resolve_scale(color='independent')
empty_chart.append(chart)
combined_chart = alt.vconcat(*empty_chart, spacing=1).resolve_scale(y='shared', x='independent', color='independent').configure_title(
anchor='start', # Aligns the title to the left ('middle' for center, 'end' for right)
offset=10, # Adjusts the distance of the title from the chart
orient='top', # Positions the title at the top; use 'bottom' to position at the bottom
)
return combined_chart
mab_plot = plot_distance_only(combined_ab,True)
mab_plot.display()
mab_plot.save(mab_plot_top)
In [20]:
mab_all = plot_distance_only(combined_ab,False)
mab_all.display()
mab_all.save(mab_plot_all)
In [ ]: